Conversation

Contributor

@yyttt6 yyttt6 commented Sep 26, 2025

Summary by CodeRabbit

  • New Features

    • Planner-driven atomic-add vectorization that chooses vector widths by device capability and reports computed plans.
    • Lazy layout inference during lowering, producing fused parallel loops with optional predicate-guarded vectorized paths.
    • Public InferLayout API added for atomic-add nodes to support dynamic layout decisions.
  • Refactor

    • Consolidated planning and rewriting into a two-stage planner→rewriter pipeline with unified dynamic handling and a simpler entrypoint.
  • Breaking Changes

    • Public vectorization entrypoint and plan/result surface simplified; older multi-parameter API removed.

Contributor

coderabbitai bot commented Sep 26, 2025

Walkthrough

Refactors AtomicAdd vectorization into a planner-driven pipeline (planner + plan result + rewriter) and simplifies the entrypoint to VectorizeAtomicAdd(const For&, int). Also adds dynamic lowering/inference via AtomicAddNode::InferLayout: Lower now builds and fuses SIMT loops, collects loop nests, infers layouts, and invokes planner-based vectorization with optional predicate guards.

Changes

  • Vectorize API & planner header (src/transform/atomicadd_vectorize.h)
    Adds AtomicAddVectorizePlanResult and AtomicAddVectorizePlanner (visitor + analyzer) with Plan(const For&, int). Changes the VectorizeAtomicAdd signature to For VectorizeAtomicAdd(const For &for_node, int compute_capability). Declares planner helpers/state (vector_size, dynamic, condition).
  • Vectorize implementation (src/transform/atomicadd_vectorize.cc)
    Replaces broad includes with the new header; implements planner-based analysis using PostOrderVisit to derive vector_size, dynamic, and condition; introduces GetVectorizeSizeMax and UpdateVectorSize; separates planning and rewriting into a planner → AtomicAddVectorizeRewriter pipeline and updates rewriter construction to accept a plan.
  • AtomicAdd lowering & layout inference (src/op/atomic_add.cc)
    Adds loop_parallel_transform_utils.h; replaces the prior arch-int helper with a local lambda; adds AtomicAddNode::InferLayout(const LayoutInferArgs&, InferLevel); restructures Lower to build/fuse SIMT loops, collect loop nests, compute/infer layouts, plan vectorization, and apply optional predicate guards during lowering.
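
Taken together, the new public surface looks roughly like the sketch below. This is a condensed approximation assembled from the summaries above; the visitor base class, member defaults, and comments are assumptions, not a copy of src/transform/atomicadd_vectorize.h.

#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>

namespace tvm {
namespace tl {

// Result of the planning phase.
struct AtomicAddVectorizePlanResult {
  int vector_size{1};   // chosen vector width
  bool dynamic{false};  // true when vectorization needs a runtime guard
  PrimExpr condition;   // the guard predicate; defined iff dynamic is true
};

// Planner: walks the loop, inspects AtomicAdd calls and dtypes, and picks a
// vector width bounded by the device compute capability.
class AtomicAddVectorizePlanner : public tir::StmtExprVisitor {
 public:
  AtomicAddVectorizePlanResult Plan(const tir::For &node, int compute_capability);
  // internal state (sketched): vector_size, dynamic, condition, ...
};

// Simplified entrypoint: plan first, rewrite only when a width > 1 pays off.
tir::For VectorizeAtomicAdd(const tir::For &for_node, int compute_capability);

}  // namespace tl
}  // namespace tvm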

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant Caller
    participant Lower as AtomicAddNode::Lower
    participant SIMT as SIMT Builder/Fuser
    participant Collector as AtomicLoopNestCollector
    participant Layout as Layout Infer
    participant Planner as AtomicAddVectorizePlanner
    participant Rewriter as AtomicAddVectorizeRewriter
    participant IR as Resulting IR

    Caller->>Lower: Lower(...)
    Lower->>SIMT: Build & fuse SIMT loop
    SIMT-->>Lower: Fused For
    Lower->>Collector: Collect loop nest & buffer indices
    Collector-->>Lower: Loop metadata
    Lower->>Layout: Compute layout / predicate (InferLayout)
    Layout-->>Lower: Layout + optional guard
    Lower->>Planner: Plan(fused_For, compute_capability)
    Note right of Planner #E6F2FF: Analyze AtomicAdd calls & dtypes → vector_size, dynamic, condition
    Planner-->>Rewriter: PlanResult
    Rewriter->>IR: Rewrite For (vectorized / guarded)
    IR-->>Lower: Lowered vectorized loop
    Lower-->>Caller: Final lowered IR

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes


Suggested reviewers

  • LeiWang1999

Poem

I hop through loops with careful care,
I count the atoms and split them fair.
I plan the stride and pick the size,
I rewrite the dance where vectors rise.
Carrots lined up, threads in a row — 🥕🐇

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: docstring coverage is 42.86%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: the title succinctly highlights the core change (fixing a variable identification error in the AtomicAdd auto-vectorization path) and directly reflects the main purpose of the PR.


👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

Contributor

Summary of Changes

Hello @yyttt6, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a bug in the atomicadd auto-vectorization pass, specifically addressing an error in how variables are identified within expressions. The core change involves refactoring the variable parsing logic to be more robust and accurate, ensuring correct vectorization. As a result of this fix, a previously failing and experimental test case related to tile_atomic_add has been re-enabled, confirming the resolution of the underlying issue.

Highlights

  • Refactored Variable Identification: The logic for identifying variables and their strides within AtomicAdd expressions has been significantly improved for robustness and accuracy during auto-vectorization.
  • Introduced ParseIndex Helper Function: A new lambda function, ParseIndex, was added to centralize and simplify the extraction of variables and their associated strides from complex expressions, making the parsing more reliable (a condensed sketch follows this list).
  • Re-enabled Test Case: A previously commented-out and failing test for tile_atomic_add functionality (test_tile_atomic_add) has been re-enabled, indicating the successful resolution of the underlying bug.
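
As a hedged illustration of that contract, here is a standalone condensation of the helper. ParseIndexSketch is a hypothetical free-function stand-in for the in-pass lambda; the real code appears in the review diffs further down this thread.

#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>

using namespace tvm;
using namespace tvm::tir;

// Succeed only when the index contains exactly one var * const product,
// reporting that variable and its stride.
bool ParseIndexSketch(const PrimExpr &idx, PrimExpr &var_out, int &stride_out) {
  int mul_count = 0, legal_mul_count = 0;
  PostOrderVisit(idx, [&](const ObjectRef &obj) {
    if (const MulNode *mul = obj.as<MulNode>()) {
      mul_count++;
      if (mul->a.as<VarNode>() && mul->b.as<IntImmNode>()) {
        var_out = mul->a;
        stride_out = static_cast<int>(mul->b.as<IntImmNode>()->value);
        legal_mul_count++;
      } else if (mul->b.as<VarNode>() && mul->a.as<IntImmNode>()) {
        var_out = mul->b;
        stride_out = static_cast<int>(mul->a.as<IntImmNode>()->value);
        legal_mul_count++;
      }
    }
  });
  // e.g. idx = bx * 64 + i  =>  var_out = bx, stride_out = 64, returns true
  return mul_count == 1 && legal_mul_count == 1;
}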
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request fixes a bug in identifying variables for atomicadd auto-vectorization by introducing a more robust ParseIndex function. The changes are a definite improvement over the previous, more brittle implementation. I've identified a potential issue in how multiple AtomicAdd calls within a loop are handled, which could lead to incorrect behavior. My review includes a suggestion to make this logic more robust. Additionally, it's good to see that a previously failing test case has been re-enabled as part of this fix.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (5)
testing/python/language/test_tilelang_language_atomic_add.py (1)

375-376: Enable test: good; consider making it hardware-agnostic (float16) to avoid cc>=90 dependency

AtomicAddx4 for float32 is only selected when compute capability >= 90. On CI GPUs < 90 (e.g., A100 cc=80), this path may not vectorize and could cause flakiness for the tile-atomic path. Two options:

  • Portable: call with float16 so vectorization is available broadly.
  • Alternatively, gate/skip on device capability.

Apply this minimal change for portability:

-def test_tile_atomic_add():
-    run_tile_atomic_add(8, 128, 128, 32, 32)
+def test_tile_atomic_add():
+    run_tile_atomic_add(8, 128, 128, 32, 32, dtype="float16")

Also, consider removing or gating the debug prints in run_tile_atomic_add to keep test output clean (prints at Lines 58, 72, 73).

src/transform/atomicadd_vectorize.cc (4)

322-347: ParseIndex is too strict; accept const-expr strides and avoid false negatives

Requiring exactly one MulNode with a Var and an IntImm will miss common canonical forms:

  • Stride may be a foldable const expr or come via casts (not a bare IntImm).
  • Extra harmless multiplies like x*1 can appear pre-simplification.
  • You only need a unique var*const match; other non-relevant muls shouldn’t invalidate the parse.

Refine by simplifying first, using as_const_int, and relaxing the check to “exactly one legal var*const mul” regardless of other muls:

-  auto ParseIndex = [](const PrimExpr &idx, PrimExpr &var_out,
-                       int &stride_out) -> bool {
+  auto ParseIndex = [](const PrimExpr &idx, PrimExpr &var_out,
+                       int &stride_out) -> bool {
     int mul_count = 0, legal_mul_count = 0;
     stride_out = -1;
     var_out = PrimExpr();
-    PostOrderVisit(idx, [&](const ObjectRef &obj) {
+    // Simplify to eliminate x*1 and fold-able constants.
+    arith::Analyzer az;
+    PrimExpr sidx = az.Simplify(idx);
+    PostOrderVisit(sidx, [&](const ObjectRef &obj) {
       if (const MulNode *mul = obj.as<MulNode>()) {
         mul_count++;
-        const VarNode *var = nullptr;
-        const IntImmNode *imm = nullptr;
-        if ((var = mul->a.as<VarNode>()) && (imm = mul->b.as<IntImmNode>())) {
-          var_out = mul->a;
-          stride_out = imm->value;
-          legal_mul_count++;
-        } else if ((var = mul->b.as<VarNode>()) &&
-                   (imm = mul->a.as<IntImmNode>())) {
-          var_out = mul->b;
-          stride_out = imm->value;
-          legal_mul_count++;
-        }
+        const VarNode *var = nullptr;
+        const int64_t *c = nullptr;
+        if ((var = mul->a.as<VarNode>()) && (c = as_const_int(mul->b))) {
+          var_out = mul->a;
+          stride_out = static_cast<int>(*c);
+          legal_mul_count++;
+        } else if ((var = mul->b.as<VarNode>()) && (c = as_const_int(mul->a))) {
+          var_out = mul->b;
+          stride_out = static_cast<int>(*c);
+          legal_mul_count++;
+        }
       }
     });
-    if (mul_count == 1 && legal_mul_count == 1)
-      return true;
-    return false;
+    return legal_mul_count == 1;
   };

Note: this relies on as_const_int and arithmetic simplification; tvm/arith/analyzer.h is already included in this file, so no new includes are needed.


362-368: Accumulate vectorize_size_max across multiple AtomicAdd sites

If the loop body contains multiple AtomicAdd calls, you currently overwrite vectorize_size_max. Prefer taking the max to avoid under-vectorizing later calls.

-          DataType dtype = bufload->dtype;
-          vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
+          DataType dtype = bufload->dtype;
+          vectorize_size_max = std::max(
+              vectorize_size_max, GetVectorizeSizeMax(compute_capability, dtype));

You’ll need:

  • Add at top: #include <algorithm> (for std::max)

362-368: Guard against mis-identifying non-block vars as bx/by

ParseIndex will happily return any var*const match (e.g., a loop var like i1). Before accepting, assert that the extracted vars are actual block indices (thread/block bindings) for safety; otherwise bail out. For example:

  • Verify var_out.as<VarNode>() is bound in thread_binding as blockIdx.{x,y} (or matches expected bx/by symbols in this pass’ context).
  • If that metadata isn’t available here, at least ensure both extracted vars differ and are not the loop var inside inner_for_.

This avoids rewriting with incorrect axes on more complex index expressions.
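
A minimal sketch of such a guard, assuming the pass can reach a Var-to-IterVar map collected from thread_extent attributes (the thread_binding parameter and IsBoundToBlockIdx helper are illustrative, not existing fields of this pass):

#include <tvm/tir/expr.h>
#include <string>
#include <unordered_map>

using namespace tvm;
using namespace tvm::tir;

// True iff e is a Var bound to the launch thread with the given tag.
bool IsBoundToBlockIdx(
    const PrimExpr &e, const std::string &tag,
    const std::unordered_map<const VarNode *, IterVar> &thread_binding) {
  const VarNode *v = e.as<VarNode>();
  if (v == nullptr) return false;
  auto it = thread_binding.find(v);
  return it != thread_binding.end() && it->second->thread_tag == tag;
}

// Usage inside the pass (bx_var/by_var come from ParseIndex):
//   if (!IsBoundToBlockIdx(bx_var, "blockIdx.x", thread_binding_) ||
//       !IsBoundToBlockIdx(by_var, "blockIdx.y", thread_binding_))
//     return for_node;  // bail out rather than rewrite with the wrong axes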


380-385: Extra sanity checks before rewriting

Before constructing the rewriter:

  • Ensure bx_var and by_var are Vars: if (!bx_var.as<VarNode>() || !by_var.as<VarNode>()) return for_node;
  • Optionally ensure stride_x > 0 && stride_y > 0.

This prevents emitting malformed truncdiv/truncmod expressions.

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between ec24561 and 5a02ea8.

📒 Files selected for processing (2)
  • src/transform/atomicadd_vectorize.cc (2 hunks)
  • testing/python/language/test_tilelang_language_atomic_add.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: format-check
  • GitHub Check: bot-task

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_atomic_add.py (1)

375-377: Remove noisy debug prints before re-enabling this test

Reactivating test_tile_atomic_add now exercises run_tile_atomic_add, which still contains print(kernel.get_kernel_source()) plus dumps of both 128×128 tensors. That’s hundreds of thousands of characters on every run and will swamp CI logs without adding assertion value. Please drop or gate those prints before merging.

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 5a02ea8 and 1255ccb.

📒 Files selected for processing (1)
  • testing/python/language/test_tilelang_language_atomic_add.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: bot-task

For VectorizeAtomicAdd(const For &for_node, const Var &thread_var,
                       const Range &thread_bounds, int compute_capability) {

  auto ParseIndex = [](const PrimExpr &idx, PrimExpr &var_out,
Member


I think we should add some comments for this function.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/op/atomic_add.cc (1)

35-46: Code duplication: GetArchInt is duplicated from src/target/utils.cc.

The GetArchInt function is duplicated from src/target/utils.cc (lines 17-25) with a different implementation. The local version includes a fallback to return 0 for non-sm_ architectures, while the version in src/target/utils.cc uses ICHECK to enforce the sm_ prefix.

Consider one of the following approaches:

  1. Preferred: Import and use the existing GetArchInt from src/target/utils.cc if the stricter validation is acceptable, or
  2. Update the version in src/target/utils.cc to include the fallback behavior and use it consistently across the codebase.

Apply this diff to use the existing function:

-static int GetArchInt(Target target) {
-  int arch_int = 0;
-  auto s = target->GetAttr<String>("arch");
-  ICHECK(s.defined());
-  std::string arch = s.value();
-  if (arch.rfind("sm_", 0) == 0) {
-    arch_int = std::stoi(arch.substr(3));
-  } else {
-    arch_int = 0;
-  }
-  return arch_int;
-}

And update the include at the top of the file if not already present:

 #include "../target/utils.h"
🧹 Nitpick comments (3)
src/op/atomic_add.cc (1)

372-372: Consider removing or adjusting the log level.

The LOG(INFO) statement prints the vectorized loop IR to the console. This may be useful during development but could be noisy in production.

Consider one of the following:

  1. Remove the log statement if it was added for debugging purposes only.
  2. Change to VLOG(1) or a higher verbosity level to reduce noise in production logs.
  3. If this is intentional diagnostic output, add a comment explaining why it's logged at INFO level.

Apply this diff to change to verbose logging:

-  LOG(INFO) << vectorized_thread_loop;
+  VLOG(1) << "Vectorized thread loop: " << vectorized_thread_loop;
src/transform/atomicadd_vectorize.cc (2)

33-55: Consider adding documentation for BufferIndiceSimplify.

The BufferIndiceSimplify class lacks documentation. Adding a brief comment explaining its purpose would improve maintainability.

Apply this diff to add documentation:

+/// \brief Simplifies buffer load and store indices using an analyzer.
+///
+/// This mutator visits BufferLoad and BufferStore nodes and simplifies
+/// their indices by applying the analyzer's Simplify method to each index.
 class BufferIndiceSimplify : public StmtExprMutator {

174-231: Consider adding documentation for the run() method.

The run() method implements complex loop transformation logic but lacks documentation explaining the transformation steps and the role of loop_layout and analyzer.

Apply this diff to add documentation:

+  /// \brief Transform and vectorize the for loop using the provided layout.
+  ///
+  /// \param for_node The original For loop to transform
+  /// \param loop_layout Fragment describing the loop layout transformation
+  /// \param analyzer Analyzer for simplifying indices and binding loop variables
+  /// \return Transformed and vectorized For loop
   For run(For for_node, const Fragment &loop_layout,
           arith::Analyzer *analyzer) {
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 202add4 and 6ceb5e0.

📒 Files selected for processing (3)
  • src/op/atomic_add.cc (1 hunks)
  • src/transform/atomicadd_vectorize.cc (6 hunks)
  • src/transform/atomicadd_vectorize.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/op/atomic_add.cc (2)
src/transform/atomicadd_vectorize.cc (2)
  • VectorizeAtomicAdd (308-343)
  • VectorizeAtomicAdd (308-310)
src/target/utils.cc (2)
  • GetArchInt (18-26)
  • GetArchInt (18-18)
src/transform/atomicadd_vectorize.cc (1)
src/transform/loop_partition.cc (5)
  • BufferIndiceSimplify (38-38)
  • LoopPragmaUnroll (201-205)
  • LoopPragmaUnroll (201-201)
  • PartitionLoop (61-105)
  • PartitionLoop (61-62)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: bot-task
  • GitHub Check: format-check
  • GitHub Check: format-check
🔇 Additional comments (5)
src/op/atomic_add.cc (1)

364-366: Call site correctly updated.

The call to VectorizeAtomicAdd now includes the required analyzer and loop_layout parameters, matching the updated API signature.

src/transform/atomicadd_vectorize.cc (3)

33-55: BufferIndiceSimplify implementation looks correct.

The mutator correctly simplifies indices for both BufferLoad and BufferStore nodes using the provided analyzer. The implementation properly uses CopyOnWrite() and maps indices through the analyzer's Simplify method.


308-343: VectorizeAtomicAdd implementation looks correct.

The function properly:

  1. Determines the maximum vectorization size based on compute capability and data type.
  2. Uses PartitionLoop to simplify the loop before planning.
  3. Plans the vectorization with the planner.
  4. Returns the original loop if vectorization is not beneficial (vectorize_hint == 1).
  5. Applies the vectorization rewrite with the provided analyzer and loop layout.

174-231: Verify vectorization across all loop levels and add tests. The run() method divides every transformed loop extent by vector_size_, not just the innermost; no existing tests cover this behavior—please confirm this is intentional and add multi-level loop vectorization test cases.

src/transform/atomicadd_vectorize.h (1)

17-19: Approve VectorizeAtomicAdd signature update
All call sites, including src/op/atomic_add.cc, have been updated to match the new parameters.

@yyttt6
Contributor Author

yyttt6 commented Oct 3, 2025

I refactored part of the AtomicAdd auto vectorization code.
This change fixes the previous error of incorrect variable matching, and makes the AtomicAdd auto vectorization more robust and adaptable to a wider range of cases.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 4526725 and 2dab926.

📒 Files selected for processing (4)
  • src/op/atomic_add.cc (1 hunks)
  • src/transform/atomicadd_vectorize.cc (6 hunks)
  • src/transform/atomicadd_vectorize.h (1 hunks)
  • tilelang/language/atomic.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/transform/atomicadd_vectorize.h
🧰 Additional context used
🧬 Code graph analysis (3)
src/transform/atomicadd_vectorize.cc (1)
src/transform/loop_partition.cc (5)
  • BufferIndiceSimplify (38-38)
  • LoopPragmaUnroll (201-205)
  • LoopPragmaUnroll (201-201)
  • PartitionLoop (61-105)
  • PartitionLoop (61-62)
src/op/atomic_add.cc (1)
src/transform/atomicadd_vectorize.cc (2)
  • VectorizeAtomicAdd (308-343)
  • VectorizeAtomicAdd (308-310)
tilelang/language/atomic.py (1)
tilelang/language/customize.py (1)
  • buffer_load_to_tile_region (46-68)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-amd
  • GitHub Check: format-check

Comment on lines +249 to +253
const BufferLoad dst_node =
    Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
const BufferLoad value_node =
    Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());

Contributor


⚠️ Potential issue | 🔴 Critical

Downcast expects an ObjectRef, not a raw node pointer

Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>()) doesn’t compile—the overload requires an ObjectRef, but as<...>() returns const BufferLoadNode*. Use the original PrimExpr instead (we already checked it’s a BufferLoad), e.g.:

-            const BufferLoad dst_node =
-                Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
-            const BufferLoad value_node =
-                Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());
+            const BufferLoad dst_node = Downcast<BufferLoad>(node->args[1]);
+            const BufferLoad value_node = Downcast<BufferLoad>(node->args[2]);

Without this change the file fails to build.


Comment on lines 193 to 211
    src_extent = list(get_extent(value))
    dst_extent = list(get_extent(dst))
    legal = True

    if (dst_extent is None and src_extent is None) or len(dst_extent) < len(src_extent):
        legal = False
    elif (dst_extent and src_extent):
        if len(dst_extent) > len(src_extent):
            dst_extent_dims = [x for x in dst_extent if x != 1]
            if dst_extent_dims != src_extent:
                legal = False
        else:
            if dst_extent != src_extent:
                legal = False
    else:
        src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
        dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
        extent = max(dst_extent, src_extent)
        dst_extent = src_extent = extent
Contributor


⚠️ Potential issue | 🔴 Critical

Guard get_extent() results before wrapping in list(...)

get_extent() still returns None for scalar PrimExpr inputs (e.g., atomic_add(dst, 1)), so list(get_extent(...)) raises a TypeError before we can fall back to the extern path. This regresses the scalar code path.

Please keep the raw result, check for None, and only convert to list when defined before the length/shape logic.

-    src_extent = list(get_extent(value))
-    dst_extent = list(get_extent(dst))
+    src_extent_raw = get_extent(value)
+    dst_extent_raw = get_extent(dst)
+    src_extent = list(src_extent_raw) if src_extent_raw is not None else None
+    dst_extent = list(dst_extent_raw) if dst_extent_raw is not None else None
     legal = True
 
-    if (dst_extent is None and src_extent is None) or len(dst_extent) < len(src_extent):
+    if dst_extent is None and src_extent is None:
+        legal = False
+    elif dst_extent is None:
+        dst_extent = [1] * len(src_extent)
+    elif src_extent is None:
+        src_extent = [1] * len(dst_extent)
+    elif len(dst_extent) < len(src_extent):
         legal = False

Make sure the remaining branches avoid len(None) as well.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/transform/atomicadd_vectorize.cc (1)

52-205: Fix AtomicAdd rewrites: size guard and correct Downcast usage.

Two problems here break compilation and correctness:

  1. We only check args.size() >= 2, yet we always read args[2]; AtomicAdd calls must therefore guard for >= 3 before touching the value operand.
  2. Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>()) is invalid—the overload expects an ObjectRef, not a raw node pointer. This reintroduces the build failure from the previous review.

Please tighten the guards and pass the original PrimExpr to Downcast:

-  if (node->op == builtin::call_extern() && node->args.size() >= 2) {
+  if (node->op == builtin::call_extern() && node->args.size() >= 3) {
     if (const auto *func_name = node->args[0].as<StringImmNode>()) {
       if (func_name->value == "AtomicAdd") {
         const BufferLoadNode *temp_dst_node =
             node->args[1].as<BufferLoadNode>();
         const BufferLoadNode *temp_value_node =
             node->args[2].as<BufferLoadNode>();
         if (!temp_dst_node || !temp_value_node) {
           return StmtExprMutator::VisitExpr_(node);
         }
-        const BufferLoad dst_node =
-            Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
-        const BufferLoad value_node =
-            Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());
+        const BufferLoad dst_node = Downcast<BufferLoad>(node->args[1]);
+        const BufferLoad value_node = Downcast<BufferLoad>(node->args[2]);
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between badf9c1 and c317f26.

📒 Files selected for processing (3)
  • src/op/atomic_add.cc (2 hunks)
  • src/transform/atomicadd_vectorize.cc (6 hunks)
  • src/transform/atomicadd_vectorize.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
src/transform/atomicadd_vectorize.h (2)
src/transform/loop_vectorize.h (2)
  • tvm (31-49)
  • tl (32-48)
src/transform/atomicadd_vectorize.cc (17)
  • VectorizeAtomicAdd (236-246)
  • VectorizeAtomicAdd (236-236)
  • AtomicAddVectorizePlanner (15-15)
  • Plan (17-44)
  • Plan (18-18)
  • node (162-184)
  • node (162-162)
  • node (186-228)
  • node (186-186)
  • VisitStmt_ (46-49)
  • VisitStmt_ (46-46)
  • VisitExpr_ (51-71)
  • VisitExpr_ (51-51)
  • GetVectorizeSizeMax (73-85)
  • GetVectorizeSizeMax (73-74)
  • UpdateVectorSize (87-127)
  • UpdateVectorSize (87-88)
src/transform/atomicadd_vectorize.cc (2)
src/transform/atomicadd_vectorize.h (1)
  • AtomicAddVectorizePlanner (36-55)
src/transform/loop_vectorize.cc (4)
  • indices (157-189)
  • indices (157-157)
  • IndiceCanVectorize (257-298)
  • IndiceCanVectorize (257-259)
src/op/atomic_add.cc (6)
src/op/parallel.cc (8)
  • Lower (184-187)
  • Lower (184-185)
  • VisitStmt_ (130-146)
  • VisitStmt_ (130-130)
  • VisitStmt_ (148-160)
  • VisitStmt_ (148-148)
  • VisitExpr_ (162-173)
  • VisitExpr_ (162-162)
src/op/copy.cc (6)
  • Lower (791-823)
  • Lower (791-791)
  • Lower (1776-1898)
  • Lower (1776-1777)
  • MakeSIMTLoop (299-344)
  • MakeSIMTLoop (299-299)
src/op/fill.cc (4)
  • Lower (171-206)
  • Lower (171-171)
  • MakeSIMTLoop (136-151)
  • MakeSIMTLoop (136-136)
src/op/reduce.cc (4)
  • Lower (152-318)
  • Lower (152-152)
  • Lower (413-437)
  • Lower (413-413)
src/target/utils.cc (2)
  • GetArchInt (18-26)
  • GetArchInt (18-18)
src/transform/atomicadd_vectorize.cc (6)
  • VisitStmt_ (46-49)
  • VisitStmt_ (46-46)
  • VisitExpr_ (51-71)
  • VisitExpr_ (51-51)
  • VectorizeAtomicAdd (236-246)
  • VectorizeAtomicAdd (236-236)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: format-check

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between c317f26 and 12e437c.

📒 Files selected for processing (2)
  • src/op/atomic_add.cc (3 hunks)
  • src/transform/atomicadd_vectorize.cc (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/op/atomic_add.cc (5)
src/op/parallel.cc (8)
  • InferLayout (212-527)
  • InferLayout (212-213)
  • VisitStmt_ (130-146)
  • VisitStmt_ (130-130)
  • VisitStmt_ (148-160)
  • VisitStmt_ (148-148)
  • VisitExpr_ (162-173)
  • VisitExpr_ (162-162)
src/op/copy.cc (6)
  • InferLayout (399-498)
  • InferLayout (399-400)
  • InferLayout (1955-1958)
  • InferLayout (1955-1956)
  • MakeSIMTLoop (299-344)
  • MakeSIMTLoop (299-299)
src/op/parallel.h (1)
  • ParallelOp (151-158)
src/target/utils.cc (2)
  • GetArchInt (18-26)
  • GetArchInt (18-18)
src/transform/atomicadd_vectorize.cc (6)
  • VisitStmt_ (45-48)
  • VisitStmt_ (45-45)
  • VisitExpr_ (50-70)
  • VisitExpr_ (50-50)
  • VectorizeAtomicAdd (235-244)
  • VectorizeAtomicAdd (235-235)
src/transform/atomicadd_vectorize.cc (2)
src/transform/atomicadd_vectorize.h (1)
  • AtomicAddVectorizePlanner (36-55)
src/transform/loop_vectorize.cc (4)
  • indices (157-189)
  • indices (157-157)
  • IndiceCanVectorize (257-298)
  • IndiceCanVectorize (257-259)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-test-metal

Comment on lines +401 to +409
auto GetArchInt = [&](const Target &tgt) -> int {
  int arch_int = 0;
  if (auto s = tgt->GetAttr<String>("arch")) {
    std::string arch = s.value();
    if (arch.rfind("sm_", 0) == 0)
      arch_int = std::stoi(arch.substr(3));
  }
  return arch_int;
};
Contributor


🛠️ Refactor suggestion | 🟠 Major

Drop the shadowing GetArchInt lambda.

We already have a file-scope GetArchInt(Target) (Lines 37-48). Redefining an identical lambda here is redundant, risks divergence, and silently bypasses the shared helper (e.g. any future validation fixes). Call the existing function instead of shadowing it locally.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (3)
src/op/atomic_add.cc (2)

375-383: Drop the shadowing GetArchInt lambda.

This lambda duplicates the file-scope GetArchInt(Target) function from src/target/utils.h (already included on Line 13). Using a local shadow risks divergence and bypasses any future validation fixes in the shared implementation.

Based on past review comments.

Apply this diff:

-  auto GetArchInt = [&](const Target &tgt) -> int {
-    int arch_int = 0;
-    if (auto s = tgt->GetAttr<String>("arch")) {
-      std::string arch = s.value();
-      if (arch.rfind("sm_", 0) == 0)
-        arch_int = std::stoi(arch.substr(3));
-    }
-    return arch_int;
-  };

And update lines 455 and 506 to call GetArchInt(target) directly.


490-493: Propagate the dynamic predicate.

The planner's dynamic predicate is captured here but never used. When plan.dynamic is true and plan.condition is defined, the final vectorized_thread_loop should be wrapped in a guard (e.g., IfThenElse(pred, vectorized_thread_loop, thread_loop)) so the vectorized path executes only when the condition holds.

Based on past review comments.

If you intend to propagate the predicate, apply a diff similar to:

   auto vectorized_thread_loop =
       VectorizeAtomicAdd(thread_loop, GetArchInt(target));
+  if (ret.predicate.defined()) {
+    return IfThenElse(ret.predicate.value(), vectorized_thread_loop, thread_loop);
+  }
   return vectorized_thread_loop;
src/transform/atomicadd_vectorize.cc (1)

227-230: Pass the original PrimExprs to Downcast.

Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>()) attempts to downcast a raw pointer, which fails to compile. Since Lines 220-224 already verify the args are BufferLoadNode*, pass the original PrimExpr objects directly to Downcast.

Based on past review comments.

Apply this diff:

-            const BufferLoad dst_node =
-                Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
-            const BufferLoad value_node =
-                Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());
+            const BufferLoad dst_node = Downcast<BufferLoad>(node->args[1]);
+            const BufferLoad value_node = Downcast<BufferLoad>(node->args[2]);
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 12e437c and 952d3de.

📒 Files selected for processing (3)
  • src/op/atomic_add.cc (4 hunks)
  • src/transform/atomicadd_vectorize.cc (6 hunks)
  • src/transform/atomicadd_vectorize.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
src/transform/atomicadd_vectorize.cc (2)
src/transform/atomicadd_vectorize.h (1)
  • AtomicAddVectorizePlanner (36-55)
src/transform/loop_vectorize.cc (4)
  • indices (157-189)
  • indices (157-157)
  • IndiceCanVectorize (257-298)
  • IndiceCanVectorize (257-259)
src/op/atomic_add.cc (3)
src/transform/atomicadd_vectorize.cc (10)
  • node (188-210)
  • node (188-188)
  • node (212-254)
  • node (212-212)
  • VisitStmt_ (72-75)
  • VisitStmt_ (72-72)
  • VisitExpr_ (77-97)
  • VisitExpr_ (77-77)
  • VectorizeAtomicAdd (262-271)
  • VectorizeAtomicAdd (262-262)
src/op/parallel.cc (8)
  • InferLayout (212-527)
  • InferLayout (212-213)
  • VisitStmt_ (130-146)
  • VisitStmt_ (130-130)
  • VisitStmt_ (148-160)
  • VisitStmt_ (148-148)
  • VisitExpr_ (162-173)
  • VisitExpr_ (162-162)
src/target/utils.cc (2)
  • GetArchInt (18-26)
  • GetArchInt (18-18)
src/transform/atomicadd_vectorize.h (2)
src/transform/loop_vectorize.h (2)
  • tvm (31-49)
  • tl (32-48)
src/transform/atomicadd_vectorize.cc (17)
  • VectorizeAtomicAdd (262-271)
  • VectorizeAtomicAdd (262-262)
  • AtomicAddVectorizePlanner (15-15)
  • Plan (17-70)
  • Plan (18-18)
  • node (188-210)
  • node (188-188)
  • node (212-254)
  • node (212-212)
  • VisitStmt_ (72-75)
  • VisitStmt_ (72-72)
  • VisitExpr_ (77-97)
  • VisitExpr_ (77-77)
  • GetVectorizeSizeMax (99-111)
  • GetVectorizeSizeMax (99-100)
  • UpdateVectorSize (113-153)
  • UpdateVectorSize (113-114)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-amd
  • GitHub Check: build-test-metal
🔇 Additional comments (7)
src/op/atomic_add.cc (2)

292-331: LGTM!

The lazy ParallelOp construction and layout validation logic is well-structured. The diagnostic output in the ICHECK will help users understand layout mismatches.


59-59: Verify Python AtomicAdd arg ordering
No C++ callers require updates; confirm Python wrappers (e.g., tvm/tir/op/atomic_add.py) and tests now use args[2] for coalesced_width instead of args[3].

src/transform/atomicadd_vectorize.cc (5)

17-70: LGTM!

The planner's PostOrderVisit approach cleanly separates the max vector size determination phase from the buffer access analysis phase. The early exit when vectorize_size_max <= 1 avoids unnecessary work.


72-97: LGTM!

The visitor pattern correctly tracks the innermost loop and processes AtomicAdd buffer accesses to refine the vector size.


99-111: LGTM!

The compute capability and dtype-based vector size limits correctly reflect hardware constraints for atomic operations.
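
For reference, a hedged sketch of such a mapping; the cutoffs below are inferred from remarks earlier in this thread (float32 AtomicAddx4 only from cc 90, half-precision pairs available more broadly) and are assumptions rather than a copy of the real GetVectorizeSizeMax table.

#include <tvm/runtime/data_type.h>

using tvm::DataType;

// Map (compute capability, dtype) to a maximum atomic vector width.
// Assumed cutoffs, not the real table.
int GetVectorizeSizeMaxSketch(int compute_capability, DataType dtype) {
  if (dtype == DataType::Float(16) || dtype == DataType::BFloat(16))
    return 2;                                 // paired half-precision atomics
  if (dtype == DataType::Float(32))
    return compute_capability >= 90 ? 4 : 1;  // AtomicAddx4 needs cc >= 90
  return 1;                                   // scalar fallback
}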


113-153: LGTM!

The vector size refinement logic correctly uses GCD operations to find alignment constraints and the IndiceCanVectorize check to verify vectorization safety. The dynamic fallback path properly captures the runtime predicate.
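
A hedged sketch of the GCD narrowing idea, with illustrative names: each divisibility constraint can only shrink the width (gcd with 0 leaves it unchanged, so a zero base offset imposes no constraint), and lane contiguity is verified separately by IndiceCanVectorize.

#include <cstdint>
#include <numeric>

int RefineVectorSizeSketch(int vector_size, int64_t loop_extent,
                           int64_t base_elem_offset) {
  auto narrow = [&](int64_t c) {
    vector_size = static_cast<int>(std::gcd<int64_t, int64_t>(vector_size, c));
  };
  narrow(loop_extent);       // the width must divide the iteration count
  narrow(base_elem_offset);  // ... and keep the first lane aligned
  // Contiguity of the lanes themselves is checked separately
  // (IndiceCanVectorize in loop_vectorize.cc).
  return vector_size;
}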


262-271: LGTM!

The entry point cleanly orchestrates the planner-driven vectorization flow. The early exit avoids unnecessary rewriting when vectorization isn't beneficial.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
src/transform/atomicadd_vectorize.h (1)

48-50: Remove unused planner state.

has_nonlocal_memory_access_ is introduced here but never referenced in atomicadd_vectorize.cc, so it will trip -Wunused-private-field (and become a hard error under -Werror). Please drop it or wire it into the implementation.

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 952d3de and 907bbed.

📒 Files selected for processing (1)
  • src/transform/atomicadd_vectorize.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/atomicadd_vectorize.h (1)
src/transform/atomicadd_vectorize.cc (17)
  • VectorizeAtomicAdd (262-271)
  • VectorizeAtomicAdd (262-262)
  • AtomicAddVectorizePlanner (15-15)
  • Plan (17-70)
  • Plan (18-18)
  • node (188-210)
  • node (188-188)
  • node (212-254)
  • node (212-212)
  • VisitStmt_ (72-75)
  • VisitStmt_ (72-72)
  • VisitExpr_ (77-97)
  • VisitExpr_ (77-77)
  • GetVectorizeSizeMax (99-111)
  • GetVectorizeSizeMax (99-100)
  • UpdateVectorSize (113-153)
  • UpdateVectorSize (113-114)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-test-metal

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (2)
src/op/atomic_add.cc (2)

372-380: Drop the shadowing GetArchInt lambda.

A file-scope GetArchInt(Target) already exists in src/target/utils.cc. Redefining an identical lambda here is redundant, risks divergence if the shared helper is updated with validation fixes, and bypasses the centralized implementation.

Replace the lambda with a call to the existing file-scope GetArchInt:

-  auto GetArchInt = [&](const Target &tgt) -> int {
-    int arch_int = 0;
-    if (auto s = tgt->GetAttr<String>("arch")) {
-      std::string arch = s.value();
-      if (arch.rfind("sm_", 0) == 0)
-        arch_int = std::stoi(arch.substr(3));
-    }
-    return arch_int;
-  };

And update the call sites (lines 452, 503) to use GetArchInt(target) directly.


487-504: Propagate planner predicate for dynamic vectorization.

AtomicAddVectorizePlanner::Plan can return dynamic=true with a condition guard. The predicate is captured at lines 487-490 but never used. When dynamic is true, the vectorized body must be wrapped with the guard before execution; otherwise the dynamic plan executes incorrectly.

Apply this diff to wrap the vectorized loop with the predicate when present:

   auto ret = AtomicAddInferLayout(transformed_loop,
                                   {T.target, T.thread_bounds, T.layout_map,
                                    analyzer, false, T.buffer_remap});
   Fragment loop_layout = ret.loop_layout;
   auto thread_loop =
       PartitionLoop(transformed_loop, T.thread_var, analyzer, loop_layout);
   auto vectorized_thread_loop =
       VectorizeAtomicAdd(thread_loop, GetArchInt(target));
-  return vectorized_thread_loop;
+  if (ret.predicate.defined()) {
+    return IfThenElse(ret.predicate.value(), vectorized_thread_loop);
+  }
+  return vectorized_thread_loop;
 }
🧹 Nitpick comments (1)
src/op/atomic_add.cc (1)

398-409: Consider adding consistency checks for buffer indices.

Unlike ParallelLoopNestVisitor (see src/op/parallel.cc lines 147-172), AtomicLoopNestCollector does not verify that the same buffer is accessed with structurally equal indices across multiple access sites. If a buffer has inconsistent access patterns, the collector silently overwrites the indices, potentially masking layout inference bugs.

Apply this diff to add consistency validation:

 void VisitStmt_(const BufferStoreNode *op) final {
   if (op->buffer.scope() == "local.fragment") {
+    if (indice_map.count(op->buffer)) {
+      ICHECK(StructuralEqual()(indice_map[op->buffer], op->indices))
+          << op->buffer << ": " << op->indices << " vs "
+          << indice_map[op->buffer];
+    }
     indice_map.Set(op->buffer, op->indices);
     writes.insert(op->buffer);
   }
   StmtExprVisitor::VisitStmt_(op);
 }
 void VisitExpr_(const BufferLoadNode *op) final {
   if (op->buffer.scope() == "local.fragment") {
+    if (indice_map.count(op->buffer)) {
+      ICHECK(StructuralEqual()(indice_map[op->buffer], op->indices))
+          << op->buffer << ": " << op->indices << " vs "
+          << indice_map[op->buffer];
+    }
     indice_map.Set(op->buffer, op->indices);
   }
   StmtExprVisitor::VisitExpr_(op);
 }
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between b3ad3a9 and 9053b42.

📒 Files selected for processing (1)
  • src/op/atomic_add.cc (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/op/atomic_add.cc (3)
src/op/parallel.cc (8)
  • InferLayout (212-527)
  • InferLayout (212-213)
  • VisitStmt_ (130-146)
  • VisitStmt_ (130-130)
  • VisitStmt_ (148-160)
  • VisitStmt_ (148-148)
  • VisitExpr_ (162-173)
  • VisitExpr_ (162-162)
src/target/utils.cc (2)
  • GetArchInt (18-26)
  • GetArchInt (18-18)
src/transform/atomicadd_vectorize.cc (6)
  • VisitStmt_ (72-75)
  • VisitStmt_ (72-72)
  • VisitExpr_ (77-97)
  • VisitExpr_ (77-77)
  • VectorizeAtomicAdd (262-271)
  • VectorizeAtomicAdd (262-262)
🔇 Additional comments (3)
src/op/atomic_add.cc (3)

16-16: LGTM!

The include for loop_parallel_transform_utils.h is correctly added to support the new ParallelLoopTransformer::Substitute call introduced in the refactored Lower method.


289-328: LGTM!

The InferLayout method correctly implements lazy initialization of the ParallelOp, validates fragment layout consistency between source and destination buffers, and properly delegates to the underlying ParallelOp for layout inference.


413-431: LGTM!

The ComputeLoopLayoutFromBuffer lambda correctly derives a fragment layout by computing the forward thread expression from buffer indices and binding to the thread range. The AtomicInferResult struct appropriately encapsulates the layout and optional predicate results from the inference process.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (2)
src/op/atomic_add.cc (2)

378-386: Drop the redundant local GetArchInt lambda.

This local lambda duplicates the file-scope GetArchInt(Target) function already available in ../target/utils.h (included at line 13). The duplication creates maintenance risk and bypasses any future improvements to the shared implementation.

Apply this diff to remove the lambda and use the file-scope function:

-  auto GetArchInt = [&](const Target &tgt) -> int {
-    int arch_int = 0;
-    if (auto s = tgt->GetAttr<String>("arch")) {
-      std::string arch = s.value();
-      if (arch.rfind("sm_", 0) == 0)
-        arch_int = std::stoi(arch.substr(3));
-    }
-    return arch_int;
-  };

Then update the call sites at lines 458 and 509 to use the file-scope GetArchInt(target) directly (note: the function is already in scope via the include at line 13).


502-510: Propagate the planner predicate for dynamic vectorization.

AtomicAddInferLayout captures a predicate from the planner when dynamic=true (lines 493-496) but this predicate is never used. When dynamic vectorization is planned, the vectorized body should be wrapped with the predicate guard to ensure correct execution.

Apply this diff to wrap the vectorized loop with the predicate when present:

   auto ret = AtomicAddInferLayout(transformed_loop,
                                   {T.target, T.thread_bounds, T.layout_map,
                                    analyzer, false, T.buffer_remap});
   Fragment loop_layout = ret.loop_layout;
   auto thread_loop =
       PartitionLoop(transformed_loop, T.thread_var, analyzer, loop_layout);
   auto vectorized_thread_loop =
       VectorizeAtomicAdd(thread_loop, GetArchInt(target));
-  return vectorized_thread_loop;
+  if (ret.predicate.defined()) {
+    return IfThenElse(ret.predicate.value(), vectorized_thread_loop);
+  } else {
+    return vectorized_thread_loop;
+  }
🧹 Nitpick comments (1)
src/op/atomic_add.cc (1)

388-417: Add index consistency validation in AtomicLoopNestCollector.

The collector sets indice_map without validating that repeated accesses to the same buffer use structurally equal indices. This differs from ParallelLoopNestVisitor (in src/op/parallel.cc lines 147-172), which validates consistency with ICHECK(StructuralEqual()(...). Silently overwriting indices can hide bugs where the same buffer is accessed with different patterns.

Apply this diff to add validation:

     void VisitStmt_(const BufferStoreNode *op) final {
       if (op->buffer.scope() == "local.fragment") {
+        if (indice_map.count(op->buffer)) {
+          ICHECK(StructuralEqual()(indice_map[op->buffer], op->indices))
+              << op->buffer << ": " << op->indices << " and "
+              << indice_map[op->buffer];
+        }
         indice_map.Set(op->buffer, op->indices);
         writes.insert(op->buffer);
       }
       StmtExprVisitor::VisitStmt_(op);
     }
     void VisitExpr_(const BufferLoadNode *op) final {
       if (op->buffer.scope() == "local.fragment") {
+        if (indice_map.count(op->buffer)) {
+          ICHECK(StructuralEqual()(indice_map[op->buffer], op->indices))
+              << op->buffer << ": " << op->indices << " and "
+              << indice_map[op->buffer];
+        }
         indice_map.Set(op->buffer, op->indices);
       }
       StmtExprVisitor::VisitExpr_(op);
     }
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 9053b42 and d96ac34.

📒 Files selected for processing (1)
  • src/op/atomic_add.cc (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/op/atomic_add.cc (3)
src/op/parallel.cc (8)
  • InferLayout (212-527)
  • InferLayout (212-213)
  • VisitStmt_ (130-146)
  • VisitStmt_ (130-130)
  • VisitStmt_ (148-160)
  • VisitStmt_ (148-148)
  • VisitExpr_ (162-173)
  • VisitExpr_ (162-162)
src/target/utils.cc (2)
  • GetArchInt (18-26)
  • GetArchInt (18-18)
src/transform/atomicadd_vectorize.cc (6)
  • VisitStmt_ (72-75)
  • VisitStmt_ (72-72)
  • VisitExpr_ (77-97)
  • VisitExpr_ (77-77)
  • VectorizeAtomicAdd (262-271)
  • VectorizeAtomicAdd (262-262)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-test-metal
🔇 Additional comments (5)
src/op/atomic_add.cc (5)

16-16: LGTM!

The new include is necessary for ParallelLoopTransformer used in the refactored lowering pipeline.


295-334: LGTM!

The InferLayout implementation properly:

  • Lazily constructs the ParallelOp when needed
  • Validates fragment layout consistency between src and dst
  • Delegates to the underlying ParallelOp for layout inference

375-376: LGTM!

The ParallelLoopTransformer::Substitute call appropriately transforms the fused loop before further processing.


419-432: LGTM!

The ComputeLoopLayoutFromBuffer lambda correctly computes a Fragment layout from buffer indices and thread bounds.


439-500: The layout inference logic is sound, but predicate handling needs attention.

The AtomicAddInferLayout implementation correctly:

  • Collects loop nest information
  • Selects the best buffer for layout inference
  • Uses the planner to determine vector size
  • Adjusts vector size based on coalesced_width annotation and divisibility constraints
  • Computes the final loop layout

However, the predicate captured at lines 493-496 is returned but not used by the caller (see next comment on lines 502-510).

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (2)
src/op/atomic_add.cc (2)

412-420: Remove the redundant GetArchInt lambda.

A file-scope GetArchInt(Target) function already exists (imported via #include "../target/utils.h" at line 13, defined in /src/target/utils.cc). This local lambda duplicates that logic and shadows the shared implementation. Use the existing function directly instead.

Apply this diff:

-  auto GetArchInt = [&](const Target &tgt) -> int {
-    int arch_int = 0;
-    if (auto s = tgt->GetAttr<String>("arch")) {
-      std::string arch = s.value();
-      if (arch.rfind("sm_", 0) == 0)
-        arch_int = std::stoi(arch.substr(3));
-    }
-    return arch_int;
-  };

Then update line 492 and 543 to call the file-scope function:

-    int sm = GetArchInt(target);
+    int sm = tl::GetArchInt(target);
-      VectorizeAtomicAdd(thread_loop, GetArchInt(target));
+      VectorizeAtomicAdd(thread_loop, tl::GetArchInt(target));

536-544: Missing predicate guard for dynamic vectorization.

The planner's predicate (captured at line 533) is never used. When plan.dynamic is true, the vectorized body must be wrapped with the predicate guard to ensure correctness. Currently, line 544 returns the vectorized loop unconditionally, which can produce incorrect results for dynamic plans.

Apply this diff to wrap the vectorized loop when a predicate exists:

   auto ret = AtomicAddInferLayout(transformed_loop,
                                   {T.target, T.thread_bounds, T.layout_map,
                                    analyzer, false, T.buffer_remap});
   Fragment loop_layout = ret.loop_layout;
   auto thread_loop =
       PartitionLoop(transformed_loop, T.thread_var, analyzer, loop_layout);
   auto vectorized_thread_loop =
       VectorizeAtomicAdd(thread_loop, tl::GetArchInt(target));
-  return vectorized_thread_loop;
+  if (ret.predicate.defined()) {
+    return IfThenElse(ret.predicate.value(), vectorized_thread_loop);
+  }
+  return vectorized_thread_loop;
 }
🧹 Nitpick comments (2)
src/op/atomic_add.cc (2)

477-490: Consider renaming read_src for clarity.

The variable read_src (line 477) doesn't verify whether the buffer is read-only; it simply selects the buffer with the highest rank from indice_map. Consider renaming it to layout_src or infer_src to better reflect its purpose.


523-524: Remove unnecessary remapped variable.

Line 523 assigns loop to remapped without any transformation. This appears to be leftover code from a refactoring. Simplify by using loop directly at line 524.

Apply this diff:

-      const For &remapped = loop;
-      loop_layout = PlanLoopPartition(remapped, vec, args.thread_bounds);
+      loop_layout = PlanLoopPartition(loop, vec, args.thread_bounds);
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between d96ac34 and c31dfcd.

📒 Files selected for processing (1)
  • src/op/atomic_add.cc (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/op/atomic_add.cc (4)
src/op/copy.cc (6)
  • InferLayout (399-498)
  • InferLayout (399-400)
  • InferLayout (1962-1965)
  • InferLayout (1962-1963)
  • MakeSIMTLoop (299-344)
  • MakeSIMTLoop (299-299)
src/op/parallel.cc (8)
  • InferLayout (212-527)
  • InferLayout (212-213)
  • VisitStmt_ (130-146)
  • VisitStmt_ (130-130)
  • VisitStmt_ (148-160)
  • VisitStmt_ (148-148)
  • VisitExpr_ (162-173)
  • VisitExpr_ (162-162)
src/target/utils.cc (2)
  • GetArchInt (18-26)
  • GetArchInt (18-18)
src/transform/atomicadd_vectorize.cc (6)
  • VisitStmt_ (72-75)
  • VisitStmt_ (72-72)
  • VisitExpr_ (77-97)
  • VisitExpr_ (77-77)
  • VectorizeAtomicAdd (262-271)
  • VectorizeAtomicAdd (262-262)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-test-metal
🔇 Additional comments (3)
src/op/atomic_add.cc (3)

307-346: LGTM: Well-structured layout inference method.

The new InferLayout method follows good patterns:

  • Lazy initialization of the ParallelOp cache
  • Proper validation of fragment layout consistency between src and dst
  • Clean delegation to the underlying ParallelOp

422-451: LGTM: Clean visitor implementation.

AtomicLoopNestCollector follows the established visitor pattern from ParallelLoopNestVisitor in /src/op/parallel.cc, appropriately simplified for atomic add operations.


453-466: LGTM: Correct layout computation.

The ComputeLoopLayoutFromBuffer lambda correctly derives loop layout from buffer layout using ForwardThread and thread range binding, consistent with the pattern in /src/op/parallel.cc.

@LeiWang1999 LeiWang1999 merged commit 340bfc5 into tile-ai:main Oct 13, 2025
6 of 7 checks passed